package vpcserver

import (
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net"
	"strconv"
	"sync"
	"time"

	"github.com/patrickmn/go-cache"
)

// Defaults applied by NewServer to zero-valued ServerConfig fields.
const (
	// DefaultTcpBufSize is the default ServerConfig.TcpBufSize (documented
	// there as "128K"; NOTE(review): the unit is presumably KiB — confirm in Proxy).
	DefaultTcpBufSize = 128
	// DefaultMaxIdealTimeOut is the default ServerConfig.MaxIdealTimeOut, in
	// minutes: 24 hours.
	DefaultMaxIdealTimeOut = 24 * 60
	// DefaultAssetMaxIdealTimeOut is the default ServerConfig.AssetMaxIdealTimeOut
	// (3 minutes, per that field's comment).
	DefaultAssetMaxIdealTimeOut = 3
)

// VpcServer is the server side of the VPC tunnel. It accepts TLS control
// connections from VPC proxy clients (StartTLSServer) and creates local
// per-asset tunnel entry points bound to those connections (GetOrCreateTunnel).
// Construct it with NewServer.
//
// NOTE(review): the exported Lock mutex is not referenced in this section of
// the file; confirm whether other code relies on it.
type VpcServer struct {
	ticketsCache *cache.Cache // one-time registration tickets issued by newTicket (key == value == ticket)
	config       *ServerConfig
	vpcConnMap   sync.Map                                          // proxy control connections: ProxyId -> *RemoteConnWrapper
	token        string                                            // token checked during the Register handshake
	proxyMap     sync.Map                                          // per-asset tunnel proxies: "ProxyId@AssetAddr" -> *Proxy
	ReverseFunc  func(proxyId, ticket, token string) (bool, error) // callback asking for a reverse tunnel to be opened
	FuncLog      func(msg string)                                  // optional log sink used by PrintLog
	Lock         sync.Mutex
}
// ServerConfig configures a VpcServer. NewServer replaces zero values of
// TcpBufSize, MaxIdealTimeOut and AssetMaxIdealTimeOut with the package
// defaults, writing them into the caller's struct.
//
// MaxIdealTimeOut is in minutes: it becomes each tunnel Proxy's
// ProxyMaxIdealTimeOut, which clearUnusedAssetProxy multiplies by time.Minute.
type ServerConfig struct {
	BindIp               string // IP the TLS server listens on (the intranet-local IP when used for NAT traversal)
	BindPort             int    // port the TLS server listens on
	Token                string // shared token proxy clients must present when registering
	TcpBufSize           int    // TCP forwarding buffer size; default 128 ("128K" — presumably KiB, confirm in Proxy)
	MaxIdealTimeOut      int    // max idle time of tunnel (business) TCP proxies, in minutes; default 24 hours
	AssetMaxIdealTimeOut int    // max idle time of proxy control connections; default 3 (minutes)
}

// NewServer builds a VpcServer from config, installing reverse as ReverseFunc.
//
// Zero-valued TcpBufSize, MaxIdealTimeOut and AssetMaxIdealTimeOut are
// replaced with DefaultTcpBufSize, DefaultMaxIdealTimeOut and
// DefaultAssetMaxIdealTimeOut. These defaults are written into the caller's
// *ServerConfig. Registration tickets are kept in an in-memory go-cache with
// a 5-minute default expiration and a 10-minute cleanup interval.
func NewServer(config *ServerConfig, reverse func(proxyId, ticket, token string) (bool, error)) *VpcServer {
	defaults := []struct {
		field    *int
		fallback int
	}{
		{&config.TcpBufSize, DefaultTcpBufSize},
		{&config.MaxIdealTimeOut, DefaultMaxIdealTimeOut},
		{&config.AssetMaxIdealTimeOut, DefaultAssetMaxIdealTimeOut},
	}
	for _, d := range defaults {
		if *d.field == 0 {
			*d.field = d.fallback
		}
	}

	server := &VpcServer{
		config:      config,
		token:       config.Token,
		ReverseFunc: reverse,
	}
	server.ticketsCache = cache.New(5*time.Minute, 10*time.Minute)
	return server
}

// newTicket issues a one-time registration ticket that stays valid for
// timeout seconds and returns it. The ticket is the local host name, an "@"
// separator and the output of randUp(24), so it also identifies the issuing
// host. It is stored in ticketsCache, where handlerProxyConn checks and
// consumes it when a proxy registers.
func (server *VpcServer) newTicket(timeout int) string {
	ticket := GetHostName() + "@" + string(randUp(24))
	ttl := time.Duration(timeout) * time.Second
	server.ticketsCache.Set(ticket, ticket, ttl)
	return ticket
}
// StartTLSServer starts, in the background, the TLS listener for VPC proxy
// control connections and returns immediately.
//
// crt and key are the certificate and private-key file paths passed to
// tls.LoadX509KeyPair. The listener binds to ServerConfig.BindIp:BindPort and
// serves each accepted connection with handlerProxyConn in its own goroutine.
// The janitors clearUnusedProxy and clearUnusedAssetProxy are started as well.
//
// Failing to load the key pair or to bind the listener is fatal: log.Panic is
// raised inside the background goroutine, which terminates the process.
func (server *VpcServer) StartTLSServer(crt, key string) {
	go func() {
		cert, err := tls.LoadX509KeyPair(crt, key)
		if err != nil {
			log.Panic(err)
		}
		tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
		// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:port"); plain
		// "%s:%s" formatting produced an unparsable address for them.
		addr := net.JoinHostPort(server.config.BindIp, strconv.Itoa(server.config.BindPort))
		listener, err := tls.Listen("tcp", addr, tlsConfig)
		if err != nil {
			log.Panic(err)
		}
		for {
			conn, err := listener.Accept()
			if err != nil {
				if errors.Is(err, net.ErrClosed) {
					// A closed listener fails every Accept; retrying would spin at full CPU.
					server.PrintLog("tls listener closed: " + err.Error())
					return
				}
				// Transient failure (e.g. out of file descriptors): back off briefly
				// instead of hot-looping.
				server.PrintLog("accept proxy connection error: " + err.Error())
				time.Sleep(50 * time.Millisecond)
				continue
			}
			go server.handlerProxyConn(conn)
		}
	}()
	go server.clearUnusedProxy()
	go server.clearUnusedAssetProxy()
}

// handlerProxyConn serves one control connection opened by a VPC proxy client
// and blocks until that connection is finished.
//
// Each frame is read with receive. Its command must be VpcAgentRequest and its
// head a JSON-encoded AgentDto whose Action selects the handling:
//
//   - Register: checks the shared token and a one-time ticket from
//     ticketsCache, stores the connection in vpcConnMap under ProxyId, replies
//     with VpcAgentResponse and starts the SendMessage writer goroutine.
//   - KeepAlive: refreshes the heartbeat time and echoes the head back through
//     the writer queue.
//   - Data: writes the frame body to the addressed client connection.
//   - CloseConn: closes the addressed client connection.
//
// Only Register is accepted until the connection has registered. Any other
// action from an unauthenticated peer ends the session. On every exit path
// the connection is closed and the writer is cancelled. This connection's
// vpcConnMap entry is also removed, unless another connection has replaced it.
func (server *VpcServer) handlerProxyConn(proxyConn net.Conn) {
	proxyId := ""
	// registered is the wrapper this connection stored in vpcConnMap, if any.
	var registered *RemoteConnWrapper
	ctx, cancel := context.WithCancel(context.Background())
	defer func() {
		if e := recover(); e != nil {
			server.PrintLog(fmt.Sprintf("handlerProxyConn Panicing %s\n", e))
		}
		// Unregister on every exit path (read error, bad frame, panic), but never
		// evict a newer connection that re-registered under the same id.
		if registered != nil {
			if current, ok := server.vpcConnMap.Load(proxyId); ok && current == registered {
				server.vpcConnMap.Delete(proxyId)
			}
		}
		_ = proxyConn.Close()
		cancel()
	}()

	// clientConnOf resolves the user-side connection a Data/CloseConn frame addresses.
	clientConnOf := func(dto *AgentDto) (*ClientConnWrapper, bool) {
		value, ok := server.proxyMap.Load(dto.ProxyId + "@" + dto.AssetAddr)
		if !ok {
			return nil, false
		}
		proxy, ok := value.(*Proxy)
		if !ok {
			return nil, false
		}
		conn, ok := proxy.ClientConnMap.Load(dto.ClientAddr)
		if !ok {
			return nil, false
		}
		wrapper, ok := conn.(*ClientConnWrapper)
		return wrapper, ok
	}

	for {
		payload, err := receive(proxyConn)
		if err != nil {
			if proxyId != "" {
				server.PrintLog("receive data err close the tunnel！ProxyId:" + proxyId + " RemoteAddr" + proxyConn.RemoteAddr().String() + "  " + err.Error())
			}
			return
		}
		if len(payload.command) == 0 || payload.command[0] != VpcAgentRequest[0] {
			// Illegal request: not an agent frame.
			server.PrintLog("this is an illegal request,close the channel!")
			return
		}
		rst := &AgentDto{}
		if err = json.Unmarshal(payload.head, rst); err != nil {
			server.PrintLog("serialize error: " + err.Error())
			return
		}
		// Without this check an unauthenticated peer could inject data into, or
		// close, any user connection.
		if registered == nil && rst.Action != Register {
			server.PrintLog("request before register, close the channel! RemoteAddr" + proxyConn.RemoteAddr().String())
			return
		}
		switch rst.Action {
		case Register:
			// Log the registration request without the shared secret.
			redacted := *rst
			redacted.Token = ""
			info, _ := json.Marshal(&redacted)
			server.PrintLog("the vpc client register info: " + string(info))
			if rst.Token != server.token {
				server.PrintLog("the token is wrong! " + rst.ProxyId)
				return
			}
			if _, found := server.ticketsCache.Get(rst.Ticket); !found {
				server.PrintLog("the ticket is nil! " + rst.ProxyId)
				return
			}
			// Tickets are single-use.
			server.ticketsCache.Delete(rst.Ticket)
			// A connection is already registered for this proxy: keep it, drop this one.
			if _, exists := server.vpcConnMap.Load(rst.ProxyId); exists {
				return
			}
			wrapper := &RemoteConnWrapper{RemoteConn: proxyConn, ProxyId: rst.ProxyId, ioChan: make(chan []byte, 100)}
			wrapper.lastHeartBeatTime = time.Now()
			server.vpcConnMap.Store(rst.ProxyId, wrapper)
			server.PrintLog("Store vpcConn success！proxyId: " + rst.ProxyId)
			proxyId = rst.ProxyId
			registered = wrapper
			// Answer before starting the writer so the response is written ahead
			// of anything already queued on ioChan. A failed write surfaces as an
			// error on the next receive.
			head, _ := json.Marshal(&AgentDto{Action: Register, Success: true})
			_, _ = proxyConn.Write(PayLoadEncode(VpcAgentResponse, head, nil))
			go server.SendMessage(ctx, wrapper)
		case KeepAlive:
			// Older clients omit ProxyId in heartbeats; fall back to this connection's id.
			if rst.ProxyId == "" {
				rst.ProxyId = proxyId
			}
			if value, ok := server.vpcConnMap.Load(rst.ProxyId); ok {
				if wrapper, ok := value.(*RemoteConnWrapper); ok {
					wrapper.lastHeartBeatTime = time.Now()
					wrapper.ioChan <- PayLoadEncode(VpcAgentResponse, payload.head, nil)
				}
			}
		case Data:
			// Upstream data: forward the body to the user's client connection.
			if client, ok := clientConnOf(rst); ok {
				if _, err := client.ClientConn.Write(payload.body); err != nil {
					_ = client.ClientConn.Close()
				}
			}
		case CloseConn:
			if client, ok := clientConnOf(rst); ok {
				_ = client.ClientConn.Close()
			}
		}
	}
}

// GetOrCreateTunnel returns the local address ("127.0.0.1:port") of the tunnel
// entry for assetAddr behind the VPC proxy proxyId, creating it on demand.
//
// If proxyId has no registered control connection yet, ReverseFunc is called
// with a fresh 180-second ticket to request a reverse tunnel. vpcConnMap is
// then polled twice, 20 ms apart, for the proxy to register. One control
// connection is shared by every asset tunnel of the same VPC.
//
// The per-asset Proxy is stored in proxyMap under "proxyId@assetAddr". An
// existing Proxy is reused only while it is bound to the currently registered
// control connection. If the proxy has reconnected since, the stale Proxy is
// stopped, its old control connection is closed, and a fresh Proxy is created.
//
// On error the returned address is "".
func (server *VpcServer) GetOrCreateTunnel(proxyId, assetAddr string) (string, error) {
	remoteConn, ok := server.vpcConnMap.Load(proxyId)
	// No tunnel for this VPC yet: ask for a reverse tunnel, reused by every asset in the VPC.
	if !ok {
		if server.ReverseFunc == nil {
			return "", errors.New("ReverseFunc is nil, can't open the tunnel ProxyId:" + proxyId)
		}
		if _, err := server.ReverseFunc(proxyId, server.newTicket(180), server.config.Token); err != nil {
			return "", fmt.Errorf("%w ProxyId:%s", err, proxyId)
		}
		// The proxy registers asynchronously over a new connection; poll briefly.
		for attempt := 0; attempt < 2 && !ok; attempt++ {
			time.Sleep(20 * time.Millisecond)
			remoteConn, ok = server.vpcConnMap.Load(proxyId)
		}
		if !ok {
			server.PrintLog("can't find the Proxy，ProxyId:" + proxyId)
			return "", errors.New("can't find the Proxy，ProxyId:" + proxyId)
		}
	}
	remoteWrapper, ok := remoteConn.(*RemoteConnWrapper)
	if !ok {
		return "", errors.New("conversion the RemoteConnWrapper err ProxyId:" + proxyId)
	}

	key := proxyId + "@" + assetAddr
	if existing, found := server.proxyMap.Load(key); found {
		old, isProxy := existing.(*Proxy)
		// Each registration creates a new wrapper, so identity tells whether the
		// tunnel still rides on the current control connection.
		if isProxy && old.RemoteConnWrapper == remoteWrapper {
			return old.Listener.Addr().String(), nil
		}
		// The proxy reconnected: tear the stale entry down and build a fresh one.
		if isProxy {
			old.stop()
			_ = old.RemoteConnWrapper.RemoteConn.Close()
		}
		server.proxyMap.Delete(key)
		server.PrintLog("the proxy has reconnect！ProxyID：" + proxyId + " asset：" + assetAddr)
	}

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return "", fmt.Errorf("listen tunnel entry ProxyId:%s asset:%s: %w", proxyId, assetAddr, err)
	}
	fresh := &Proxy{
		ProxyId:              proxyId,
		RemoteConnWrapper:    remoteWrapper,
		Listener:             listener,
		AssetAddr:            assetAddr,
		TcpBufSize:           server.config.TcpBufSize,
		ProxyMaxIdealTimeOut: server.config.MaxIdealTimeOut,
		AssetMaxIdealTimeOut: server.config.AssetMaxIdealTimeOut,
		// Without this, clearUnusedAssetProxy could reap the proxy before its first use.
		LastVisitTime: time.Now(),
	}
	// LoadOrStore keeps concurrent callers from each installing (and leaking) a Proxy.
	if actual, loaded := server.proxyMap.LoadOrStore(key, fresh); loaded {
		if winner, isProxy := actual.(*Proxy); isProxy {
			_ = listener.Close()
			return winner.Listener.Addr().String(), nil
		}
		server.proxyMap.Store(key, fresh)
	}
	go fresh.run()
	return listener.Addr().String(), nil
}

// PrintLog forwards msg to the FuncLog sink; it does nothing when no sink is set.
func (server *VpcServer) PrintLog(msg string) {
	logFn := server.FuncLog
	if logFn == nil {
		return
	}
	logFn(msg)
}
// clearUnusedAssetProxy runs forever and, every 5 minutes, stops and removes
// the per-asset tunnel proxies in proxyMap whose LastVisitTime is older than
// their ProxyMaxIdealTimeOut (in minutes; 24 hours by default).
func (server *VpcServer) clearUnusedAssetProxy() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	for range ticker.C {
		server.proxyMap.Range(func(key, value any) bool {
			p, isProxy := value.(*Proxy)
			if !isProxy {
				return true
			}
			maxIdle := time.Duration(p.ProxyMaxIdealTimeOut) * time.Minute
			if time.Since(p.LastVisitTime) > maxIdle {
				p.stop()
				server.proxyMap.Delete(key.(string))
			}
			return true
		})
	}
}

// clearUnusedProxy runs forever. Every HeartbeatInterval it looks for proxy
// control connections in vpcConnMap whose last heartbeat is older than
// ServerConfig.AssetMaxIdealTimeOut minutes (DefaultAssetMaxIdealTimeOut when
// unset). Each such connection is closed and removed from vpcConnMap.
// Closing it should also make the owning handler's next receive fail.
func (server *VpcServer) clearUnusedProxy() {
	// Honour the documented config knob instead of a hard-coded 3 minutes.
	idleMinutes := DefaultAssetMaxIdealTimeOut
	if server.config != nil && server.config.AssetMaxIdealTimeOut > 0 {
		idleMinutes = server.config.AssetMaxIdealTimeOut
	}
	maxIdle := time.Duration(idleMinutes) * time.Minute

	ticker := time.NewTicker(HeartbeatInterval)
	defer ticker.Stop()
	for range ticker.C {
		server.vpcConnMap.Range(func(key, value any) bool {
			wrapper, ok := value.(*RemoteConnWrapper)
			if ok && time.Since(wrapper.lastHeartBeatTime) > maxIdle {
				server.PrintLog("heartbeat timeout, close the tunnel！ProxyId:" + wrapper.ProxyId)
				// Best effort: the connection is being discarded either way.
				_ = wrapper.RemoteConn.Close()
				server.vpcConnMap.Delete(key)
			}
			return true
		})
	}
}

// SendMessage is the writer loop for a proxy control connection. It writes
// every frame received on remoteConnWrapper.ioChan to RemoteConn until ctx is
// cancelled, ioChan is closed, or a write fails.
//
// On a write failure the connection is closed, so the reader side stops too
// instead of leaving producers blocked on a full ioChan.
func (server *VpcServer) SendMessage(ctx context.Context, remoteConnWrapper *RemoteConnWrapper) {
	for {
		select {
		case <-ctx.Done():
			return
		case msg, ok := <-remoteConnWrapper.ioChan:
			if !ok {
				// A closed channel is always ready; without this return the loop would spin.
				return
			}
			if _, err := remoteConnWrapper.RemoteConn.Write(msg); err != nil {
				server.PrintLog("write to vpc proxy failed, close the tunnel！ProxyId:" + remoteConnWrapper.ProxyId + " " + err.Error())
				_ = remoteConnWrapper.RemoteConn.Close()
				return
			}
		}
	}
}

// FindVpcConnById returns the registered control connection of the VPC proxy
// proxyId. It returns an error when proxyId is empty or no such connection is
// registered.
func (server *VpcServer) FindVpcConnById(proxyId string) (net.Conn, error) {
	if len(proxyId) == 0 {
		return nil, errors.New("proxyId can't be none")
	}
	// A missing key yields a nil value, for which the assertion simply fails.
	value, _ := server.vpcConnMap.Load(proxyId)
	if wrapper, isWrapper := value.(*RemoteConnWrapper); isWrapper {
		return wrapper.RemoteConn, nil
	}
	return nil, errors.New("vpcConn is none")
}
